#!/usr/bin/env python3
import os
import numpy as np
import sys
sys.path.append('..')
from collections import namedtuple
from glob import glob
import time
import gc
import pickle
from utils.deeplearningutilities.tf import Trainer, MyCheckpointManager
from evaluate_network_ptport import evaluate
from argoverse.map_representation.map_api import ArgoverseMap
from train_utils import *

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

# Select which GPUs CUDA exposes to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '2, 3'

# os.path.join does not expand '~', so resolve the home directory up front
# to hand the loaders a real filesystem path.
dataset_path = os.path.expanduser('~/particle/argoverse/argoverse_forecasting/')
lane_path = os.path.expanduser('~/particle/TrafficFluids/datasets/')

# Selects the lane-aware loader and the 'lane_data' sub-directories.
use_lane = False

# NOTE: don't forget to change the LR scheduler (see main) when these
# settings change.

# Selects models.encoder_model instead of models.default in create_model.
use_encoder = False

# Forwarded to the encoder model; only used when use_encoder is True.
encoder_func = 'lstm_ctsconv'

# Number of consecutive prediction steps that contribute to the training loss.
train_window = 3

# When True, batches are passed through normalize_input with normalize_scale.
use_normalize_input = False
normalize_scale = 3

# Number of consecutive batches whose gradients are accumulated before each
# optimizer step; the per-batch size shrinks by the same factor.
batch_divide = 1
TrainParams = namedtuple('TrainParams', ['epochs', 'batches_per_epoch', 'base_lr', 'batch_size'])
# Floor division keeps batch_size an int ('/' would produce the float 16.0).
train_params = TrainParams(100, 10 * batch_divide, 0.001, 16 // batch_divide)

# Base name of the saved checkpoint and of the pickled result files.
model_name = 'ctsconv_first_10_steps_2'

# checkpoint_path = os.path.join('checkpoint_models', model_name)

# os.system('mkdir ' + checkpoint_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pick the loader implementation together with the matching data
# sub-directory of the train/val splits.
if not use_lane:
    from datasets.argoverse_pickle_loader import read_pkl_data
    train_path, val_path = (
        os.path.join(dataset_path, split, 'clean_data') for split in ('train', 'val')
    )
else:
    from datasets.argoverse_lane_loader import read_pkl_data
    train_path, val_path = (
        os.path.join(dataset_path, split, 'lane_data') for split in ('train', 'val')
    )

def create_model(encoder_func):
    """Return an instance of the network for training and evaluation.

    The module-level ``use_encoder`` flag selects the implementation: when it
    is set, ``models.encoder_model.ParticlesNetwork`` is built with the given
    ``encoder_func``; otherwise ``models.default.ParticlesNetwork`` is built
    and ``encoder_func`` is ignored.  Both are created with
    ``radius_scale=40``.
    """
    # The imports stay inside the branches so only the selected model module
    # is loaded.
    if use_encoder:
        from models.encoder_model import ParticlesNetwork
        return ParticlesNetwork(encoder_func=encoder_func, radius_scale=40)

    from models.default import ParticlesNetwork
    return ParticlesNetwork(radius_scale=40)


class MyDataParallel(torch.nn.DataParallel):
    """
    Allow nn.DataParallel to call model's attributes.

    Attribute lookups that fail on the wrapper itself are forwarded to the
    wrapped ``self.module``.
    """
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If 'module' itself cannot be resolved (e.g. the wrapper has not
            # been initialised), forwarding would look up self.module again
            # and recurse without bound; surface the AttributeError instead.
            if name == 'module':
                raise
            return getattr(self.module, name)

def main():
    """Train the network, validating and checkpointing after every epoch.

    Each epoch draws ``train_params.batches_per_epoch`` batches from the
    training loader (created with ``repeat=True``), accumulates gradients over
    ``batch_divide`` consecutive batches per optimizer step, and then runs
    ``evaluate`` on the validation loader.  The wrapped module is saved to
    ``<model_name>.pth`` whenever the validation loss is the best seen so far
    (the first epoch included), and the list of per-epoch validation metrics
    is re-pickled to ``results/<model_name>_val_metrics.pickle`` after every
    epoch.
    """
    # Create the output directory up front so a missing folder cannot fail
    # the metrics dump after a full epoch of training.
    os.makedirs('results', exist_ok=True)

    am = ArgoverseMap()

    val_dataset = read_pkl_data(val_path, batch_size=8, shuffle=False, repeat=False)

    dataset = read_pkl_data(train_path, batch_size=train_params.batch_size, repeat=True, shuffle=True)

    data_iter = iter(dataset)

    model_ = create_model(encoder_func)
    model = MyDataParallel(model_).to(device)
    optimizer = torch.optim.Adam(model.parameters(), train_params.base_lr, betas=(0.9, 0.999), weight_decay=4e-4)
    # The learning rate is multiplied by gamma once per epoch (scheduler.step below).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9968)

    def train_one_batch(model, batch, train_window=2):
        """Roll the model out for ``train_window`` steps and return the scaled loss.

        The first step is predicted from ``pos_2s``/``vel_2s`` and
        ``pos0``/``vel0``.  Every later step feeds the previous prediction
        back in, together with the ``states`` returned by the previous call.
        The prediction of step ``k`` is compared against ``batch['pos<k>']``.
        """
        batch_size = train_params.batch_size

        inputs = ([
            batch['pos_2s'], batch['vel_2s'],
            batch['pos0'], batch['vel0'],
            batch['accel'], None,
            batch['lane'], batch['lane_norm'],
            batch['car_mask'], batch['lane_mask']
        ])

        pr_pos1, pr_vel1, states = model(inputs)
        gt_pos1 = batch['pos1']

        # The third loss_fn argument is the per-sample sum of car_mask minus one.
        losses = 0.5 * loss_fn(pr_pos1, gt_pos1,
                               torch.sum(batch['car_mask'], dim=-1).unsqueeze(-1) - 1, batch['car_mask'])
        del gt_pos1

        pos0 = batch['pos0']
        vel0 = batch['vel0']
        for i in range(train_window - 1):
            # The state of the previous step takes the place of the history
            # inputs; the latest prediction becomes the current state.
            pos_enc = torch.unsqueeze(pos0, 2)
            vel_enc = torch.unsqueeze(vel0, 2)
            inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, batch['accel'], None,
                      batch['lane'],
                      batch['lane_norm'], batch['car_mask'], batch['lane_mask'])
            pos0, vel0 = pr_pos1, pr_vel1
            del pos_enc, vel_enc

            pr_pos1, pr_vel1, states = model(inputs, states)

            gt_pos1 = batch['pos' + str(i + 2)]
            clean_cache(device)

            losses += 0.5 * loss_fn(pr_pos1, gt_pos1,
                                    torch.sum(batch['car_mask'], dim=-1).unsqueeze(-1) - 1, batch['car_mask'])

        total_loss = 128 * torch.sum(losses, axis=0) / batch_size

        return total_loss

    epochs = train_params.epochs
    # The dataset is too large to run through completely, so every epoch only
    # consumes a fixed number of batches.
    batches_per_epoch = train_params.batches_per_epoch
    data_load_times = []  # per batch
    train_losses = []
    valid_losses = []
    valid_metrics_list = []
    min_loss = None

    # Keys converted to float32 tensors on the training device; these lists do
    # not change between batches, so they are built once.
    convert_keys = (['pos' + str(step) for step in range(train_window + 1)] +
                    ['vel' + str(step) for step in range(train_window + 1)] +
                    ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])
    mask_keys = ['car_mask', 'lane_mask']
    # Keys copied over unchanged.
    passthrough_keys = ['track_id' + str(idx) for idx in range(30)] + ['city']

    for epoch in range(epochs):
        epoch_start_time = time.time()

        model.train()
        epoch_train_loss = 0
        sub_idx = 0  # position inside the current gradient-accumulation group

        print("training ... epoch " + str(epoch + 1), end='')
        for batch_itr in range(batches_per_epoch):

            data_fetch_start = time.time()
            batch = next(data_iter)

            if sub_idx == 0:
                # Start of a new accumulation group.
                optimizer.zero_grad()
                if (batch_itr // batch_divide) % 25 == 0:
                    print("... batch " + str((batch_itr // batch_divide) + 1), end='', flush=True)
            sub_idx += 1

            batch_size = len(batch['pos0'])

            if not use_lane:
                # Without lane data, supply a placeholder lane mask per sample.
                batch['lane_mask'] = [np.array([0])] * batch_size

            batch_tensor = {}
            for k in convert_keys:
                batch_tensor[k] = torch.tensor(np.stack(batch[k]), dtype=torch.float32, device=device)

            if use_normalize_input:
                batch_tensor, _ = normalize_input(batch_tensor, normalize_scale, train_window)

            for k in mask_keys:
                batch_tensor[k] = torch.tensor(np.stack(batch[k]), dtype=torch.float32, device=device)

            for k in passthrough_keys:
                batch_tensor[k] = batch[k]

            batch_tensor['car_mask'] = batch_tensor['car_mask'].squeeze(-1)
            batch_tensor['accel'] = torch.zeros(batch_size, 1, 3).to(device)
            del batch

            data_fetch_latency = time.time() - data_fetch_start
            data_load_times.append(data_fetch_latency)

            current_loss = train_one_batch(model, batch_tensor, train_window=train_window)
            if sub_idx < batch_divide:
                # Keep accumulating gradients; the optimizer steps on the last
                # batch of the group.
                # NOTE(review): retain_graph=True is kept from the original;
                # confirm whether the model actually needs it across batches.
                current_loss.backward(retain_graph=True)
            else:
                current_loss.backward()
                optimizer.step()
                sub_idx = 0
            del batch_tensor

            epoch_train_loss += float(current_loss)
            del current_loss
            clean_cache(device)

            if batch_itr == batches_per_epoch - 1:
                print("... DONE", flush=True)

        train_losses.append(epoch_train_loss)

        model.eval()
        with torch.no_grad():
            valid_total_loss, valid_metrics = evaluate(model.module, val_dataset, am=am,
                                                       train_window=train_window, max_iter=50,
                                                       device=device, use_lane=use_lane,
                                                       use_normalize_input=use_normalize_input,
                                                       normalize_scale=normalize_scale,
                                                       batch_size=val_dataset.batch_size)

        valid_losses.append(float(valid_total_loss))
        valid_metrics_list.append(valid_metrics)

        # Save on the first epoch as well: previously the first validation
        # loss only initialised min_loss, so no checkpoint was written unless
        # a later epoch improved on it.
        if min_loss is None or valid_losses[-1] < min_loss:
            min_loss = valid_losses[-1]
            torch.save(model.module, model_name + ".pth")

        epoch_end_time = time.time()

        print('epoch: {}, train loss: {}, val loss: {}, epoch time: {}, lr: {}, {}'.format(
            epoch + 1, train_losses[-1], valid_losses[-1],
            round((epoch_end_time - epoch_start_time) / 60, 5),
            format(get_lr(optimizer), "5.2e"), model_name
        ))

        scheduler.step()

        # Rewrite the metrics file every epoch so partial results survive an
        # interrupted run.
        with open('results/{}_val_metrics.pickle'.format(model_name), 'wb') as f:
            pickle.dump(valid_metrics_list, f)
        

def final_evaluation():
    """Evaluate the saved checkpoint on the validation set and pickle the result.

    Loads ``<model_name>.pth``, runs ``evaluate`` over the validation loader
    starting at iteration 50, and writes the returned metrics to
    ``results/<model_name>_predictions.pickle``.
    """
    # Make sure the output directory exists before the evaluation starts.
    os.makedirs('results', exist_ok=True)

    am = ArgoverseMap()

    val_dataset = read_pkl_data(val_path, batch_size=32, shuffle=False, repeat=False)

    # map_location lets a checkpoint written on a GPU be loaded onto whichever
    # device is available now (including CPU-only hosts).
    # NOTE(review): the checkpoint is a whole pickled module (see torch.save in
    # main); newer PyTorch releases default torch.load to weights_only=True,
    # which may reject it -- confirm against the installed version.
    trained_model = torch.load(model_name + '.pth', map_location=device)
    trained_model.eval()

    with torch.no_grad():
        _, valid_metrics = evaluate(trained_model, val_dataset, am=am,
                                    train_window=train_window, max_iter=len(val_dataset),
                                    device=device, start_iter=50, use_lane=use_lane,
                                    use_normalize_input=use_normalize_input,
                                    normalize_scale=normalize_scale)

    with open('results/{}_predictions.pickle'.format(model_name), 'wb') as f:
        pickle.dump(valid_metrics, f)
        
        
if __name__ == '__main__':
    # Script entry point: run training (validation happens after each epoch).
    main()
    
    # Uncomment to evaluate the checkpoint saved as <model_name>.pth.
    # final_evaluation()
    
    
    